{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "model.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.8"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "S89AJpQYG3du",
        "colab": {}
      },
      "source": [
        "import math\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras.layers as layers\n",
        "import tensorflow_datasets as tfds\n",
        "from tensorflow.keras.callbacks import LearningRateScheduler\n",
        "from tensorflow.keras.initializers import Constant\n",
        "from tensorflow.keras.models import Model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bpe2e0QhvLKX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Mini-batch size applied to both the training and the validation pipeline.\n",
        "BATCH_SIZE = 128\n",
        "# Directory the trained model is exported to; the TFLite converter reads it back from here.\n",
        "SAVED_MODEL_DIR = './saved_model'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fc96HiziSiOR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load MNIST as (image, label) pairs; the 'test' split serves as validation data.\n",
        "(ds_train_data, ds_val_data), info = tfds.load(\n",
        "    'mnist',\n",
        "    split=['train', 'test'],\n",
        "    as_supervised=True,\n",
        "    with_info=True,\n",
        ")\n",
        "\n",
        "# Take the class count from the dataset metadata instead of hard-coding it.\n",
        "num_classes = info.features['label'].num_classes"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x8N_qOpgSlLG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def preprocess(image, label):\n",
        "    \"\"\"Cast `image` to float32 and divide it by 255; `label` passes through unchanged.\n",
        "\n",
        "    For 8-bit pixel data this maps intensities into the [0, 1] range.\n",
        "    \"\"\"\n",
        "    scaled = tf.cast(image, tf.float32) / 255.0\n",
        "    return scaled, label\n",
        "\n",
        "# Let tf.data choose the map parallelism and prefetch depth at runtime.\n",
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "# Training input: preprocess and cache the individual examples, shuffle with a buffer\n",
        "# covering the whole training split, then batch.\n",
        "ds_train = (\n",
        "    ds_train_data\n",
        "    .map(preprocess, num_parallel_calls=AUTOTUNE)\n",
        "    .cache()\n",
        "    .shuffle(buffer_size=info.splits['train'].num_examples)\n",
        "    .batch(batch_size=BATCH_SIZE)\n",
        "    .prefetch(buffer_size=AUTOTUNE)\n",
        ")\n",
        "\n",
        "# Validation input: no shuffling; the cache sits after batching, so whole batches are cached.\n",
        "ds_val = (\n",
        "    ds_val_data\n",
        "    .map(preprocess, num_parallel_calls=AUTOTUNE)\n",
        "    .batch(batch_size=BATCH_SIZE)\n",
        "    .cache()\n",
        "    .prefetch(buffer_size=AUTOTUNE)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_Lq0YDUYiTMN",
        "colab": {}
      },
      "source": [
        "def _bn_relu_dropout(x, dropout_rate=0.25):\n",
        "    \"\"\"Append the BatchNorm -> ReLU -> Dropout tail shared by every hidden layer.\n",
        "\n",
        "    Args:\n",
        "        x: Output tensor of the preceding Conv2D/Dense layer.\n",
        "        dropout_rate: Fraction of activations dropped during training.\n",
        "\n",
        "    Returns:\n",
        "        The tensor produced by the Dropout layer.\n",
        "    \"\"\"\n",
        "    # scale=False omits BatchNorm's gamma factor; beta is initialised to 0.01.\n",
        "    x = layers.BatchNormalization(scale=False, beta_initializer=Constant(0.01))(x)\n",
        "    x = layers.Activation('relu')(x)\n",
        "    return layers.Dropout(rate=dropout_rate)(x)\n",
        "\n",
        "\n",
        "inputs = layers.Input(shape=(28, 28, 1), name='input')\n",
        "\n",
        "# Three convolution stages with growing filter counts and shrinking kernels;\n",
        "# the second and third use stride 2.\n",
        "x = layers.Conv2D(24, kernel_size=(6, 6), strides=1)(inputs)\n",
        "x = _bn_relu_dropout(x)\n",
        "\n",
        "x = layers.Conv2D(48, kernel_size=(5, 5), strides=2)(x)\n",
        "x = _bn_relu_dropout(x)\n",
        "\n",
        "x = layers.Conv2D(64, kernel_size=(4, 4), strides=2)(x)\n",
        "x = _bn_relu_dropout(x)\n",
        "\n",
        "# Fully connected head.\n",
        "x = layers.Flatten()(x)\n",
        "x = layers.Dense(200)(x)\n",
        "x = _bn_relu_dropout(x)\n",
        "\n",
        "predictions = layers.Dense(num_classes, activation='softmax', name='output')(x)\n",
        "\n",
        "model = Model(inputs=inputs, outputs=predictions)\n",
        "# The sparse loss variant is used because preprocess() passes labels through\n",
        "# unchanged, i.e. without one-hot encoding them.\n",
        "model.compile(\n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'],\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "nBzYWAEAiwzx",
        "colab": {}
      },
      "source": [
        "# Print the layer-by-layer summary of the model as a sanity check before training.\n",
        "model.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7L0zZdYRw3C_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def lr_decay(epoch):\n",
        "    \"\"\"Learning rate for `epoch`: 0.0001 + 0.02 * e^(-epoch / 3).\n",
        "\n",
        "    Starts at 0.0201 for epoch 0 and decays exponentially towards the 0.0001 floor.\n",
        "    \"\"\"\n",
        "    floor_lr, decaying_lr, time_constant = 0.0001, 0.02, 3.0\n",
        "    return floor_lr + decaying_lr * math.pow(1.0 / math.e, epoch / time_constant)\n",
        "\n",
        "\n",
        "# verbose=1 makes the scheduler report the learning rate it sets for each epoch.\n",
        "decay_callback = LearningRateScheduler(schedule=lr_decay, verbose=1)\n",
        "\n",
        "model.fit(\n",
        "    x=ds_train,\n",
        "    validation_data=ds_val,\n",
        "    epochs=20,\n",
        "    callbacks=[decay_callback],\n",
        "    verbose=1,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SwCAPICrxmRc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Export the trained model to SAVED_MODEL_DIR; the TFLite converter loads it from that directory.\n",
        "tf.saved_model.save(model, SAVED_MODEL_DIR)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9XVL5ULexulp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Convert the exported SavedModel into a TensorFlow Lite model.\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)\n",
        "# Optional float16 quantization (currently disabled); uncomment both lines to enable:\n",
        "# converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "# converter.target_spec.supported_types = [tf.float16]\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "# The converted model is written out in binary mode.\n",
        "with open('mnist.tflite', mode='wb') as tflite_file:\n",
        "    tflite_file.write(tflite_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "51PTkdoPDOTW",
        "colab": {}
      },
      "source": [
        "# Offer the converted model as a browser download when running on Google Colab.\n",
        "try:\n",
        "    from google.colab import files\n",
        "except ImportError:\n",
        "    # Outside Colab the module does not exist, so the download step is skipped.\n",
        "    # Only ImportError is caught: a failing download (e.g. missing file) or a\n",
        "    # KeyboardInterrupt must surface instead of being reported as a skip.\n",
        "    print(\"Skip downloading\")\n",
        "else:\n",
        "    files.download('mnist.tflite')"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}